{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Working with the Python DyNet package\n",
    "\n",
    "The DyNet package is intended for training and using neural networks, and is particularly suited for applications with dynamically changing network structures. It is a Python wrapper for the DyNet C++ package.\n",
    "\n",
    "In neural network packages there are generally two modes of operation:\n",
    "\n",
    "* __Static networks__, in which a network is built once and then fed with different inputs/outputs. Most NN packages work this way.\n",
    "* __Dynamic networks__, in which a new network is built for each training example (sharing parameters with the networks of other training examples). This approach is what makes DyNet unique, and where most of its power comes from.\n",
    "\n",
    "We will describe both of these modes."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Package Fundamentals\n",
    "\n",
    "The main piece of DyNet is the `ComputationGraph`, which is what essentially defines a neural network.\n",
    "The `ComputationGraph` is composed of expressions, which relate to the inputs and outputs of the network,\n",
    "as well as the `Parameters` of the network. The parameters are the things in the network that are optimized over time, and all of the parameters sit inside a `ParameterCollection`. There are `trainers` (for example `SimpleSGDTrainer`) that are in charge of setting the parameter values.\n",
    "\n",
    "We will not be using the `ComputationGraph` directly, but it is there in the background, as a singleton object.\n",
    "When `dynet` is imported, a new `ComputationGraph` is created. We can then reset the computation graph to a new state\n",
    "by calling `renew_cg()`.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Static Networks\n",
    "\n",
    "The life-cycle of a DyNet program is:\n",
    "1. Create a `ParameterCollection`, and populate it with `Parameters`.\n",
    "2. Renew the computation graph, and create an `Expression` representing the network\n",
    "      (the network will include the `Expression`s for the `Parameters` defined in the parameter collection).\n",
    "3. Optimize the model for the objective of the network.\n",
    "\n",
    "As an example, consider a model for solving the \"xor\" problem. The network has two inputs, which can be 0 or 1, and a single output which should be the xor of the two inputs.\n",
    "We will model this as a multi-layer perceptron with a single hidden layer.\n",
    "\n",
    "Let $x = x_1, x_2$ be our input. We will have a hidden layer of 8 nodes, and an output layer of a single node. The activation on the hidden layer will be a $\\tanh$. Our network will then be:\n",
    "\n",
    "$\\sigma(V(\\tanh(Wx+b)))$\n",
    "\n",
    "Where $W$ is an $8 \\times 2$ matrix, $V$ is a $1 \\times 8$ matrix, and $b$ is an 8-dim vector.\n",
    "\n",
    "We want the output to be either 0 or 1, so we take the output layer to be the logistic-sigmoid function, $\\sigma(x)$, which maps any real number in $(-\\infty, +\\infty)$ to a number in $(0,1)$.\n",
    "\n",
    "We will begin by defining the model and the computation graph.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we assume that the dynet module is on your Python path.\n",
    "import dynet as dy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<_dynet.ComputationGraph at 0x7f79486546c0>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# create a parameter collection and add the parameters.\n",
    "m = dy.ParameterCollection()\n",
    "W = m.add_parameters((8,2))\n",
    "V = m.add_parameters((1,8))\n",
    "b = m.add_parameters((8))\n",
    "\n",
    "dy.renew_cg() # new computation graph. not strictly needed here, but good practice.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.31104135513305664,\n",
       " -0.36465519666671753,\n",
       " -0.43395277857780457,\n",
       " 0.5421143770217896,\n",
       " -0.3137839138507843,\n",
       " 0.16922643780708313,\n",
       " 0.3162959814071655,\n",
       " -0.08413488417863846]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect the initial values of the bias vector b\n",
    "b.value()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The first block creates a parameter collection and populates it with parameters.\n",
    "The second block creates a computation graph.\n",
    "\n",
    "The model parameters can be used as expressions in the computation graph. We now make use of V, W, and b in order to create the complete expression for the network."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = dy.vecInput(2) # an input vector of size 2. Also an expression.\n",
    "output = dy.logistic(V*(dy.tanh((W*x)+b)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.41870421171188354"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# we can now query our network\n",
    "x.set([0,0])\n",
    "output.value()\n"
   ]
  },
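  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the shapes in the expression above concrete, here is a minimal plain-Python sketch of the same computation, $\\sigma(V(\\tanh(Wx+b)))$. It does not use DyNet; the helper names (`matvec`, `logistic`, `forward`) and the all-zero parameter values are assumptions made only for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def matvec(M, v):\n",
    "    # multiply a matrix (a list of rows) by a vector (a list of numbers)\n",
    "    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]\n",
    "\n",
    "def logistic(z):\n",
    "    return 1.0 / (1.0 + math.exp(-z))\n",
    "\n",
    "def forward(W, V, b, x):\n",
    "    # hidden = tanh(W x + b): W is 8x2, x has 2 entries, b has 8 entries\n",
    "    hidden = [math.tanh(a + b_i) for a, b_i in zip(matvec(W, x), b)]\n",
    "    # output = sigma(V hidden): V is 1x8, so the product has a single entry\n",
    "    return logistic(matvec(V, hidden)[0])\n",
    "\n",
    "# hypothetical all-zero parameters, chosen only to make the result easy to check\n",
    "W_zero = [[0.0, 0.0] for _ in range(8)]\n",
    "V_zero = [[0.0] * 8]\n",
    "b_zero = [0.0] * 8\n",
    "print(forward(W_zero, V_zero, b_zero, [1, 0]))\n",
    "```\n",
    "\n",
    "With all-zero parameters the hidden layer is $\\tanh(0) = 0$ and the output is $\\sigma(0) = 0.5$ for any input. The parameters in the notebook start from non-zero initial values (see `b.value()` above), which is why the query above returns a different number."
   ]
  },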
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# we want to be able to define a loss, so we need an input expression to work against.\n",
    "y = dy.scalarInput(0) # this will hold the correct answer\n",
    "loss = dy.binary_log_loss(output, y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5666957497596741\n",
      "0.837935209274292\n"
     ]
    }
   ],
   "source": [
    "x.set([1,0])\n",
    "y.set(0)\n",
    "print(loss.value())\n",
    "\n",
    "y.set(1)\n",
    "print(loss.value())\n"
   ]
  },
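  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the binary log loss of a predicted probability $p$ against a gold label $y \\in \\{0, 1\\}$ is $-(y \\log p + (1-y) \\log(1-p))$. The plain-Python sketch below is not DyNet code (the function name simply mirrors `dy.binary_log_loss`). It uses this formula to relate the two losses printed above, which were computed for the same prediction with `y` set to 0 and then to 1.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def binary_log_loss(p, y):\n",
    "    # p: predicted probability in (0, 1); y: gold label, 0 or 1\n",
    "    return -(y * math.log(p) + (1 - y) * math.log(1 - p))\n",
    "\n",
    "# with y=0 the loss is -log(1 - p), so p can be recovered from the first printed loss\n",
    "loss_y0 = 0.5666957497596741\n",
    "p = 1 - math.exp(-loss_y0)\n",
    "print(p)\n",
    "\n",
    "# the loss of the same prediction against y=1\n",
    "print(binary_log_loss(p, 1))\n",
    "```\n",
    "\n",
    "The second printed value should agree with the second loss above (about 0.838) up to floating-point precision."
   ]
  },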
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training\n",
    "We now want to set the parameter weights such that the loss is minimized. \n",
    "\n",
    "For this, we will use a trainer object. A trainer is constructed with respect to the parameters of a given model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = dy.SimpleSGDTrainer(m)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use the trainer, we need to:\n",
    "* **call the `forward_scalar`** method of `ComputationGraph`. This will run a forward pass through the network, calculating all the intermediate values until the last one (`loss`, in our case), and then convert the value to a scalar. The final output of our network **must** be a single scalar value. However, if we do not care about the value, we can just use `cg.forward()` instead of `cg.forward_scalar()`.\n",
    "* **call the `backward`** method of `ComputationGraph`. This will run a backward pass from the last node, calculating the gradients with respect to minimizing the last expression (in our case we want to minimize the loss). The gradients are stored in the parameter collection, and we can now let the `trainer` take care of the optimization step.\n",
    "* **call `trainer.update()`** to optimize the values with respect to the latest gradients.\n",
    "\n",
    "In the code that follows, the passes are triggered through the expression itself: `loss.value()` performs the forward pass and `loss.backward()` computes the gradients."
   ]
  },
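  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three steps above can be illustrated without DyNet on a one-parameter logistic model. This is only a sketch under stated assumptions: the functions `forward`, `backward` and `update` are hypothetical stand-ins for the forward pass, the backward pass and `trainer.update()`, and the learning rate of 0.1 is an arbitrary choice.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def logistic(z):\n",
    "    return 1.0 / (1.0 + math.exp(-z))\n",
    "\n",
    "def forward(w, x, y):\n",
    "    # forward pass: prediction, then binary log loss as a single scalar\n",
    "    p = logistic(w * x)\n",
    "    return -(y * math.log(p) + (1 - y) * math.log(1 - p))\n",
    "\n",
    "def backward(w, x, y):\n",
    "    # backward pass: for this model d(loss)/dw is (p - y) * x\n",
    "    return (logistic(w * x) - y) * x\n",
    "\n",
    "def update(w, grad, learning_rate=0.1):\n",
    "    # plain SGD step; the learning rate is an assumption of this sketch\n",
    "    return w - learning_rate * grad\n",
    "\n",
    "w, x, y = 0.0, 1.0, 1\n",
    "loss_before = forward(w, x, y)\n",
    "w = update(w, backward(w, x, y))\n",
    "loss_after = forward(w, x, y)\n",
    "print(loss_before, loss_after)\n",
    "```\n",
    "\n",
    "The loss starts at $\\log 2$ and is lower after the single update, which is the same pattern the DyNet trainer produces on the real network."
   ]
  },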
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "the loss before step is: 0.837935209274292\n",
      "the loss after step is: 0.6856433749198914\n"
     ]
    }
   ],
   "source": [
    "x.set([1,0])\n",
    "y.set(1)\n",
    "loss_value = loss.value() # this performs a forward through the network.\n",
    "print(\"the loss before step is:\",loss_value)\n",
    "\n",
    "# now do an optimization step\n",
    "loss.backward()  # compute the gradients\n",
    "trainer.update()\n",
    "\n",
    "# see how it affected the loss:\n",
    "loss_value = loss.value(recalculate=True) # recalculate=True means \"don't use precomputed value\"\n",
    "print(\"the loss after step is:\",loss_value)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The optimization step indeed made the loss decrease. We now need to run this in a loop.\n",
    "To this end, we will create a `training set`, and iterate over it.\n",
    "\n",
    "For the xor problem, the training instances are easy to create."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_xor_instances(num_rounds=2000):\n",
    "    questions = []\n",
    "    answers = []\n",
    "    for round in range(num_rounds):\n",
    "        for x1 in 0,1:\n",
    "            for x2 in 0,1:\n",
    "                answer = 0 if x1==x2 else 1\n",
    "                questions.append((x1,x2))\n",
    "                answers.append(answer)\n",
    "    return questions, answers \n",
    "\n",
    "questions, answers = create_xor_instances()"
   ]
  },
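  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a cross-check, the same instances can be generated more compactly. The function name `create_xor_instances_compact` and the use of `itertools.product` are choices made for this sketch; the order of the pairs matches the nested loops above.\n",
    "\n",
    "```python\n",
    "from itertools import product\n",
    "\n",
    "def create_xor_instances_compact(num_rounds=2000):\n",
    "    # all four input pairs, in the order (0,0), (0,1), (1,0), (1,1), repeated num_rounds times\n",
    "    questions = list(product((0, 1), repeat=2)) * num_rounds\n",
    "    # the xor of two bits is 1 exactly when they differ\n",
    "    answers = [x1 ^ x2 for x1, x2 in questions]\n",
    "    return questions, answers\n",
    "\n",
    "questions_c, answers_c = create_xor_instances_compact()\n",
    "print(len(questions_c))\n",
    "print(questions_c[:4], answers_c[:4])\n",
    "```\n",
    "\n",
    "With the default `num_rounds=2000` there are 8000 instances, so a training loop that reports every 100 instances prints 80 lines."
   ]
  },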
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now feed each question / answer pair to the network, and try to minimize the loss.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average loss is: 0.7301900261640548\n",
      "average loss is: 0.7117043241858483\n",
      "average loss is: 0.6865043716629347\n",
      "average loss is: 0.6486480497568846\n",
      "average loss is: 0.5961489198803902\n",
      "average loss is: 0.5363389208291968\n",
      "average loss is: 0.4800921744853258\n",
      "average loss is: 0.4316547167254612\n",
      "average loss is: 0.39094374931520887\n",
      "average loss is: 0.35675006832927464\n",
      "average loss is: 0.3278251909870993\n",
      "average loss is: 0.30312755185179413\n",
      "average loss is: 0.2818366446174108\n",
      "average loss is: 0.2633153002815587\n",
      "average loss is: 0.2470681765327851\n",
      "average loss is: 0.23270742912660353\n",
      "average loss is: 0.2199265043853837\n",
      "average loss is: 0.20848063852793228\n",
      "average loss is: 0.19817242979072033\n",
      "average loss is: 0.18884113303828054\n",
      "average loss is: 0.18035465603760842\n",
      "average loss is: 0.17260352481601082\n",
      "average loss is: 0.1654962877224645\n",
      "average loss is: 0.1589559833255286\n",
      "average loss is: 0.15291739877592772\n",
      "average loss is: 0.14732492341812198\n",
      "average loss is: 0.14213085135462245\n",
      "average loss is: 0.13729403161759754\n",
      "average loss is: 0.13277878321490474\n",
      "average loss is: 0.12855401825974697\n",
      "average loss is: 0.12459252774257273\n",
      "average loss is: 0.12087039561367419\n",
      "average loss is: 0.11736651676206031\n",
      "average loss is: 0.11406219601699644\n",
      "average loss is: 0.11094081379626212\n",
      "average loss is: 0.10798754755103598\n",
      "average loss is: 0.10518913371828259\n",
      "average loss is: 0.10253366940838628\n",
      "average loss is: 0.10001044190178315\n",
      "average loss is: 0.09760978312254884\n",
      "average loss is: 0.09532294639377718\n",
      "average loss is: 0.09314199849195401\n",
      "average loss is: 0.0910597273951862\n",
      "average loss is: 0.08906956217108845\n",
      "average loss is: 0.08716550341245925\n",
      "average loss is: 0.08534206202271415\n",
      "average loss is: 0.08359420659250896\n",
      "average loss is: 0.08191731647852672\n",
      "average loss is: 0.08030714023444915\n",
      "average loss is: 0.07875976019569207\n",
      "average loss is: 0.0772715599823734\n",
      "average loss is: 0.07583919591308445\n",
      "average loss is: 0.07445957204553229\n",
      "average loss is: 0.07312981744738796\n",
      "average loss is: 0.07184726620641198\n",
      "average loss is: 0.07060943942273817\n",
      "average loss is: 0.06941402890125142\n",
      "average loss is: 0.06825888305458899\n",
      "average loss is: 0.06714199365698732\n",
      "average loss is: 0.06606148393452167\n",
      "average loss is: 0.06501559805323477\n",
      "average loss is: 0.0640026917649741\n",
      "average loss is: 0.06302122345515748\n",
      "average loss is: 0.06206974616015941\n",
      "average loss is: 0.06114690068011316\n",
      "average loss is: 0.060251408738302856\n",
      "average loss is: 0.05938206722341311\n",
      "average loss is: 0.05853774277446278\n",
      "average loss is: 0.057717366610177914\n",
      "average loss is: 0.05691992998316917\n",
      "average loss is: 0.056144480362147565\n",
      "average loss is: 0.05539011717681812\n",
      "average loss is: 0.05465598844808259\n",
      "average loss is: 0.053941287759839356\n",
      "average loss is: 0.053245250818133354\n",
      "average loss is: 0.052567153106946006\n",
      "average loss is: 0.051906307245069484\n",
      "average loss is: 0.05126206048185943\n",
      "average loss is: 0.05063379273556108\n",
      "average loss is: 0.05002091444953112\n"
     ]
    }
   ],
   "source": [
    "total_loss = 0\n",
    "seen_instances = 0\n",
    "for question, answer in zip(questions, answers):\n",
    "    x.set(question)\n",
    "    y.set(answer)\n",
    "    seen_instances += 1\n",
    "    total_loss += loss.value()\n",
    "    loss.backward()\n",
    "    trainer.update()\n",
    "    if (seen_instances > 1 and seen_instances % 100 == 0):\n",
    "        print(\"average loss is:\",total_loss / seen_instances)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Our network is now trained. Let's verify that it indeed learned the xor function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0,1 0.998213529586792\n",
      "1,0 0.9983397722244263\n",
      "0,0 0.0007906468817964196\n",
      "1,1 0.0021107089705765247\n"
     ]
    }
   ],
   "source": [
    "x.set([0,1])\n",
    "print(\"0,1\",output.value())\n",
    "\n",
    "x.set([1,0])\n",
    "print(\"1,0\",output.value())\n",
    "\n",
    "x.set([0,0])\n",
    "print(\"0,0\",output.value())\n",
    "\n",
    "x.set([1,1])\n",
    "print(\"1,1\",output.value())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In case we are curious about the parameter values, we can query them:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 2.85107112,  2.83952975],\n",
       "       [-3.29001093,  2.51486993],\n",
       "       [-1.92002058, -1.90759397],\n",
       "       [-1.4002918 , -1.43046546],\n",
       "       [ 0.10682328, -0.89503163],\n",
       "       [-1.78532696,  2.70406151],\n",
       "       [ 1.20831835, -0.47131985],\n",
       "       [ 0.92750639, -2.0729847 ]])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "W.value()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 4.12795115,  4.78487778, -2.21212292,  2.85010242,  1.01012611,\n",
       "        -3.31246257, -1.09919119,  2.19970202]])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "V.value()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-0.937517523765564,\n",
       " -1.1073024272918701,\n",
       " 0.33602145314216614,\n",
       " 2.16909122467041,\n",
       " 0.17579713463783264,\n",
       " 0.7122746706008911,\n",
       " 0.0978747308254242,\n",
       " -0.16976511478424072]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "b.value()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### To summarize\n",
    "Here is a complete program:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average loss is: 0.7205019667744637\n",
      "average loss is: 0.6994020892679691\n",
      "average loss is: 0.6667007360855738\n",
      "average loss is: 0.6162097529321909\n",
      "average loss is: 0.5529091520011425\n",
      "average loss is: 0.4909045885503292\n",
      "average loss is: 0.43754800063158783\n",
      "average loss is: 0.39328361499123277\n",
      "average loss is: 0.3566404625069764\n",
      "average loss is: 0.3260498847439885\n",
      "average loss is: 0.3002264795401557\n",
      "average loss is: 0.27818110284395514\n",
      "average loss is: 0.2591624129871623\n",
      "average loss is: 0.242597631290555\n",
      "average loss is: 0.22804588572122156\n",
      "average loss is: 0.21516385969764087\n",
      "average loss is: 0.20368098794153947\n",
      "average loss is: 0.1933815449432263\n",
      "average loss is: 0.18409163802911185\n",
      "average loss is: 0.17566965313232505\n",
      "average loss is: 0.16799916441020157\n",
      "average loss is: 0.16098361631537872\n",
      "average loss is: 0.1545422787496658\n",
      "average loss is: 0.14860714952888276\n",
      "average loss is: 0.1431205491213128\n",
      "average loss is: 0.13803323951316998\n",
      "average loss is: 0.13330293813159827\n",
      "average loss is: 0.12889313311440803\n",
      "average loss is: 0.1247721328136736\n",
      "average loss is: 0.12091229626559652\n",
      "average loss is: 0.11728940626944326\n",
      "average loss is: 0.11388215588252933\n",
      "average loss is: 0.11067172380397096\n",
      "average loss is: 0.10764142127125524\n",
      "average loss is: 0.10477639951383962\n",
      "average loss is: 0.10206340225982584\n",
      "average loss is: 0.09949055765566693\n",
      "average loss is: 0.09704720211006995\n",
      "average loss is: 0.09472372988238931\n",
      "average loss is: 0.09251146528020035\n",
      "average loss is: 0.09040255263388135\n",
      "average loss is: 0.08838986149038343\n",
      "average loss is: 0.08646690461080694\n",
      "average loss is: 0.08462776689538838\n",
      "average loss is: 0.08286704372017023\n",
      "average loss is: 0.08117978683943637\n",
      "average loss is: 0.07956145741833136\n",
      "average loss is: 0.07800788381944584\n",
      "average loss is: 0.07651522612707613\n",
      "average loss is: 0.07507994277025573\n",
      "average loss is: 0.07369876249170607\n",
      "average loss is: 0.07236865903304603\n",
      "average loss is: 0.07108682857070751\n",
      "average loss is: 0.06985066948984577\n",
      "average loss is: 0.06865776468049312\n",
      "average loss is: 0.06750586506561376\n",
      "average loss is: 0.06639287554648737\n",
      "average loss is: 0.06531684178260862\n",
      "average loss is: 0.06427593874778614\n",
      "average loss is: 0.06326845997859103\n",
      "average loss is: 0.062292808323533684\n",
      "average loss is: 0.061347487150436086\n",
      "average loss is: 0.06043109254734147\n",
      "average loss is: 0.059542306310031566\n",
      "average loss is: 0.05867988926305686\n",
      "average loss is: 0.057842675803783064\n",
      "average loss is: 0.057029568297517444\n",
      "average loss is: 0.056239532018431314\n",
      "average loss is: 0.05547159101134942\n",
      "average loss is: 0.05472482365616763\n",
      "average loss is: 0.05399835920601617\n",
      "average loss is: 0.05329137419280111\n",
      "average loss is: 0.052603089143243006\n",
      "average loss is: 0.05193276596526974\n",
      "average loss is: 0.051279704820535454\n",
      "average loss is: 0.05064324217525274\n",
      "average loss is: 0.050022748128961556\n",
      "average loss is: 0.04941762432470941\n",
      "average loss is: 0.048827302071201034\n",
      "average loss is: 0.048251240481033165\n"
     ]
    }
   ],
   "source": [
    "# define the parameters\n",
    "m = dy.ParameterCollection()\n",
    "W = m.add_parameters((8,2))\n",
    "V = m.add_parameters((1,8))\n",
    "b = m.add_parameters((8))\n",
    "\n",
    "# renew the computation graph\n",
    "dy.renew_cg()\n",
    "\n",
    "# create the network\n",
    "x = dy.vecInput(2) # an input vector of size 2.\n",
    "output = dy.logistic(V*(dy.tanh((W*x)+b)))\n",
    "# define the loss with respect to an output y.\n",
    "y = dy.scalarInput(0) # this will hold the correct answer\n",
    "loss = dy.binary_log_loss(output, y)\n",
    "\n",
    "# create training instances\n",
    "def create_xor_instances(num_rounds=2000):\n",
    "    questions = []\n",
    "    answers = []\n",
    "    for round in range(num_rounds):\n",
    "        for x1 in 0,1:\n",
    "            for x2 in 0,1:\n",
    "                answer = 0 if x1==x2 else 1\n",
    "                questions.append((x1,x2))\n",
    "                answers.append(answer)\n",
    "    return questions, answers \n",
    "\n",
    "questions, answers = create_xor_instances()\n",
    "\n",
    "# train the network\n",
    "trainer = dy.SimpleSGDTrainer(m)\n",
    "\n",
    "total_loss = 0\n",
    "seen_instances = 0\n",
    "for question, answer in zip(questions, answers):\n",
    "    x.set(question)\n",
    "    y.set(answer)\n",
    "    seen_instances += 1\n",
    "    total_loss += loss.value()\n",
    "    loss.backward()\n",
    "    trainer.update()\n",
    "    if (seen_instances > 1 and seen_instances % 100 == 0):\n",
    "        print(\"average loss is:\",total_loss / seen_instances)\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dynamic Networks\n",
    "\n",
    "Dynamic networks are very similar to static ones, but instead of creating the network once and then calling `set` for each training example to change the inputs, we create a new network for each training example.\n",
    "\n",
    "We present an example below. While the value of this may not be clear in the `xor` example, the dynamic approach\n",
    "is very convenient for networks for which the structure is not fixed, such as recurrent or recursive networks."
   ]
  },
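  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why rebuilding the network per example helps when the structure varies, consider the plain-Python sketch below. It is not DyNet code: `run_chain` is a hypothetical function that applies one $\\tanh$ step per input element, so each example gets a computation of a different depth while the parameters `w` and `u` are shared.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def run_chain(w, u, inputs):\n",
    "    # build and evaluate a new chain for this example: one tanh step per input,\n",
    "    # so the depth of the computation depends on len(inputs)\n",
    "    h = 0.0\n",
    "    steps = 0\n",
    "    for x in inputs:\n",
    "        h = math.tanh(w * h + u * x)  # the same w and u are reused at every step\n",
    "        steps += 1\n",
    "    return h, steps\n",
    "\n",
    "# hypothetical shared parameters, fixed here only for illustration\n",
    "w, u = 0.5, 1.0\n",
    "for example in ([1.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0, 1.0]):\n",
    "    h, steps = run_chain(w, u, example)\n",
    "    print(len(example), steps, h)\n",
    "```\n",
    "\n",
    "In DyNet the analogous pattern is to call `dy.renew_cg()` and construct the expressions afresh for every example, as the next cell does for the xor network."
   ]
  },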
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "average loss is: 0.7242387273907661\n",
      "average loss is: 0.6815587529540061\n",
      "average loss is: 0.6179609213272731\n",
      "average loss is: 0.5381991682946682\n",
      "average loss is: 0.46485630394518374\n",
      "average loss is: 0.4053867908070485\n",
      "average loss is: 0.3582047475874424\n",
      "average loss is: 0.32044148206710815\n",
      "average loss is: 0.2897336838249531\n",
      "average loss is: 0.2643519359687343\n",
      "average loss is: 0.24305510523898358\n",
      "average loss is: 0.22494598247110845\n",
      "average loss is: 0.20936599294225183\n",
      "average loss is: 0.195823337073837\n",
      "average loss is: 0.18394447944127024\n",
      "average loss is: 0.17344117014989024\n",
      "average loss is: 0.16408769504578016\n",
      "average loss is: 0.15570493978655173\n",
      "average loss is: 0.14814903329935317\n",
      "average loss is: 0.1413031330066733\n",
      "average loss is: 0.1350713880714916\n",
      "average loss is: 0.12937444295332004\n",
      "average loss is: 0.12414604435290169\n",
      "average loss is: 0.11933044975992137\n",
      "average loss is: 0.11488042589090765\n",
      "average loss is: 0.11075568636440529\n",
      "average loss is: 0.10692166014971143\n",
      "average loss is: 0.10334851395771173\n",
      "average loss is: 0.10001036719587664\n",
      "average loss is: 0.09688465872919187\n",
      "average loss is: 0.0939516312917394\n",
      "average loss is: 0.09119390803352871\n",
      "average loss is: 0.08859614507605632\n",
      "average loss is: 0.08614474194569459\n",
      "average loss is: 0.08382760104763189\n",
      "average loss is: 0.0816339253942715\n",
      "average loss is: 0.07955404841146003\n",
      "average loss is: 0.07757928909470425\n",
      "average loss is: 0.07570182968754895\n",
      "average loss is: 0.07391461094305851\n",
      "average loss is: 0.07221124223728732\n",
      "average loss is: 0.070585923835613\n",
      "average loss is: 0.06903338058765025\n",
      "average loss is: 0.06754880329556975\n",
      "average loss is: 0.06612779856276595\n",
      "average loss is: 0.06476634448308133\n",
      "average loss is: 0.06346075249892819\n",
      "average loss is: 0.062207633629247236\n",
      "average loss is: 0.06100386795333625\n",
      "average loss is: 0.059846579102938995\n",
      "average loss is: 0.058733110335135064\n",
      "average loss is: 0.05766100448007749\n",
      "average loss is: 0.05662798507044636\n",
      "average loss is: 0.05563194011197926\n",
      "average loss is: 0.054670907723167066\n",
      "average loss is: 0.05374306264720092\n",
      "average loss is: 0.05284670477963804\n",
      "average loss is: 0.051980248590979466\n",
      "average loss is: 0.05114221337111272\n",
      "average loss is: 0.050331215119435606\n",
      "average loss is: 0.04954595835507298\n",
      "average loss is: 0.04878522939514369\n",
      "average loss is: 0.04804788986208021\n",
      "average loss is: 0.04733287107374053\n",
      "average loss is: 0.04663916844668655\n",
      "average loss is: 0.04596583708531618\n",
      "average loss is: 0.045311987426708826\n",
      "average loss is: 0.04467678093440447\n",
      "average loss is: 0.04405942677290759\n",
      "average loss is: 0.04345917822181114\n",
      "average loss is: 0.04287532994617105\n",
      "average loss is: 0.04230721487954724\n",
      "average loss is: 0.04175420160314438\n",
      "average loss is: 0.041215692378306835\n",
      "average loss is: 0.04069112060347883\n",
      "average loss is: 0.040179948867359934\n",
      "average loss is: 0.03968166718186883\n",
      "average loss is: 0.0391957911709622\n",
      "average loss is: 0.03872186044427048\n",
      "average loss is: 0.03825943737498892\n"
     ]
    }
   ],
   "source": [
    "import dynet as dy\n",
    "# create training instances, as before\n",
    "def create_xor_instances(num_rounds=2000):\n",
    "    questions = []\n",
    "    answers = []\n",
    "    for round in range(num_rounds):\n",
    "        for x1 in 0,1:\n",
    "            for x2 in 0,1:\n",
    "                answer = 0 if x1==x2 else 1\n",
    "                questions.append((x1,x2))\n",
    "                answers.append(answer)\n",
    "    return questions, answers \n",
    "\n",
    "questions, answers = create_xor_instances()\n",
    "\n",
    "# create a network for the xor problem given input and output\n",
    "def create_xor_network(W, V, b, inputs, expected_answer):\n",
    "    dy.renew_cg() # new computation graph\n",
    "    x = dy.vecInput(len(inputs))\n",
    "    x.set(inputs)\n",
    "    y = dy.scalarInput(expected_answer)\n",
    "    output = dy.logistic(V*(dy.tanh((W*x)+b)))\n",
    "    loss =  dy.binary_log_loss(output, y)\n",
    "    return loss\n",
    "\n",
    "m2 = dy.ParameterCollection()\n",
    "W = m2.add_parameters((8,2))\n",
    "V = m2.add_parameters((1,8))\n",
    "b = m2.add_parameters((8))\n",
    "trainer = dy.SimpleSGDTrainer(m2)\n",
    "\n",
    "seen_instances = 0\n",
    "total_loss = 0\n",
    "for question, answer in zip(questions, answers):\n",
    "    loss = create_xor_network(W, V, b, question, answer)\n",
    "    seen_instances += 1\n",
    "    total_loss += loss.value()\n",
    "    loss.backward()\n",
    "    trainer.update()\n",
    "    if (seen_instances > 1 and seen_instances % 100 == 0):\n",
    "        print(\"average loss is:\",total_loss / seen_instances)\n",
    "\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python3",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
